import numpy as np
import pandas as pd
# from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def split_sample_cv(G_column, X_columns, Y_column, df, train_index, test_index,
                    standardize=False):
    """Split the sample into train and test sets for one cross-validation fold.

    Parameters
    ----------
    G_column : str
        group variable; test rows with value 1 are "blue", value 0 are "red"
    X_columns : list of str
        covariates
    Y_column : str
        outcome variable
    df : pandas.DataFrame
        whole dataset
    train_index : numpy.ndarray
        row *positions* (not index labels) of the training set; they are
        used with ``DataFrame.iloc``
    test_index : numpy.ndarray
        row *positions* (not index labels) of the test set
    standardize : bool, optional
        whether to standardize the covariates, by default False.
        The scaler is fit on the training rows only and then applied to
        both the training and the test rows.

    Returns
    -------
    X_test_blue : pandas.DataFrame
        covariates of the test rows with ``G_column == 1``
    y_test_blue : numpy.ndarray
        outcomes of those rows, flattened to 1-D
    X_test_red : pandas.DataFrame
        covariates of the test rows with ``G_column == 0``
    y_test_red : numpy.ndarray
        outcomes of those rows, flattened to 1-D
    train : pandas.DataFrame
        all columns of the training rows with a fresh ``0..n-1`` index
        (covariates standardized when ``standardize`` is True)

    Notes
    -----
    Test rows whose group value is neither 0 nor 1 appear in neither the
    blue nor the red outputs.
    """
    # Positional selection (iloc) is used deliberately: a label-based
    # ``.loc`` lookup was noticeably slower here, and the fold indices are
    # row positions.
    train = df.iloc[train_index].copy()
    test = df.iloc[test_index].copy()
    train.reset_index(drop=True, inplace=True)
    test.reset_index(drop=True, inplace=True)

    if standardize:
        # Fit on the training fold only so no test information leaks into
        # the scaling parameters.
        scaler = StandardScaler()
        scaler.fit(train[X_columns])
        train[X_columns] = scaler.transform(train[X_columns])
        test[X_columns] = scaler.transform(test[X_columns])

    is_blue = test[G_column] == 1
    is_red = test[G_column] == 0
    X_test = test[X_columns]
    y_test = test[Y_column]

    # Boolean selection yields new frames, independent of ``test``.
    X_test_blue = X_test[is_blue]
    y_test_blue = y_test[is_blue].to_numpy().ravel()
    X_test_red = X_test[is_red]
    y_test_red = y_test[is_red].to_numpy().ravel()

    return X_test_blue, y_test_blue, X_test_red, y_test_red, train


def balance_train_size(train, G_column, X_columns, Y_column, ratio=1.0,
                       seed=42, verbose=False):
    """Balance the training-set sizes of the blue and red groups.

    Rows with ``train[G_column] == 1`` form the blue group and rows with
    ``train[G_column] == 0`` form the red group. The larger group is
    randomly subsampled (without replacement) down to
    ``int(ratio * size_of_smaller_group)`` rows, capped at the number of
    rows the larger group actually has. The smaller group is kept as-is.
    When both groups have the same size, the blue group is the one that
    gets subsampled.

    Parameters
    ----------
    train : pandas.DataFrame
        training set
    G_column : str
        group variable
    X_columns : list of str
        covariates
    Y_column : str
        outcome variable
    ratio : float, optional
        ratio of the large group to the small group, by default 1.0
        If None, we do not balance the training set size.
    seed : int, optional
        for randomization, by default 42
    verbose : bool, optional
        whether to print which branch was taken and the target size,
        by default False

    Returns
    -------
    df_X_train_blue_full, df_X_train_red_full : pandas.DataFrame
        covariates of the (possibly subsampled) blue and red groups
    df_y_train_blue_full, df_y_train_red_full
        ``train[Y_column]`` restricted to the same rows

    Notes
    -----
    Rows whose group value is neither 0 nor 1 are dropped.
    """
    train_blue = train[train[G_column] == 1]
    train_red = train[train[G_column] == 0]

    blue_size = train_blue.shape[0]
    red_size = train_red.shape[0]

    if ratio is None:
        if verbose:
            print('no balancing')
    elif blue_size < red_size:
        # Cap at the available rows: sampling more rows than exist without
        # replacement would raise a ValueError.
        size = min(int(blue_size*ratio), red_size)
        if verbose:
            print(f'blue < red, size={size}')
        # Keep the sampled frame itself instead of re-selecting its labels
        # with .loc, which would duplicate rows if index labels repeat.
        train_red = train_red.sample(n=size, replace=False,
                                     random_state=seed)
    else:
        size = min(int(red_size*ratio), blue_size)
        if verbose:
            print(f'blue >= red, size={size}')
        train_blue = train_blue.sample(n=size, replace=False,
                                       random_state=seed)

    df_X_train_blue_full = train_blue[X_columns].copy()
    df_X_train_red_full = train_red[X_columns].copy()
    df_y_train_blue_full = train_blue[Y_column].copy()
    df_y_train_red_full = train_red[Y_column].copy()

    return df_X_train_blue_full, df_X_train_red_full,\
           df_y_train_blue_full, df_y_train_red_full
